{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import json\n",
    "import zipfile\n",
    "import sqlite3\n",
    "from collections import Counter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def file_lines(archive_path='dgk_shooter_min.conv.zip',\n",
    "               member='dgk_shooter_min.conv'):\n",
    "    \"\"\"Read the zipped corpus and return its cleaned lines as a list of str.\n",
    "\n",
    "    archive_path -- zip file to open (defaults to the original hard-coded name).\n",
    "    member       -- name of the text file inside the archive.\n",
    "\n",
    "    A line starting with 'E' contributes an empty string (presumably a\n",
    "    conversation boundary -- confirm against the corpus format). A line\n",
    "    starting with 'M ' contributes its text with the '/' separators removed,\n",
    "    trailing '.' tokens dropped and every whitespace run replaced by '，'.\n",
    "    Any other line is skipped.\n",
    "    \"\"\"\n",
    "    with zipfile.ZipFile(archive_path, 'r') as archive:\n",
    "        with archive.open(member, 'r') as fp:\n",
    "            raw = fp.read()\n",
    "    # Bytes that are not valid UTF-8 are dropped instead of aborting the load.\n",
    "    content = raw.decode('utf8', 'ignore')\n",
    "    lines = []\n",
    "    # No try/except here: the old bare except also caught KeyboardInterrupt and\n",
    "    # then returned a silently truncated list.\n",
    "    for line in tqdm(content.split('\\n')):\n",
    "        line = line.strip()\n",
    "        if line.startswith('E'):\n",
    "            lines.append('')\n",
    "        elif line.startswith('M '):\n",
    "            chars = line[2:].split('/')\n",
    "            while chars and chars[-1] == '.':\n",
    "                chars.pop()\n",
    "            if chars:\n",
    "                # Raw string avoids the invalid-escape warning for '\\s'.\n",
    "                lines.append(re.sub(r'\\s+', '，', ''.join(chars)))\n",
    "    return lines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 4268085/4268085 [00:10<00:00, 404549.97it/s]\n"
     ]
    }
   ],
   "source": [
    "# Parse the whole corpus once; the following cells reuse `lines`.\n",
    "lines = file_lines()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4268080\n"
     ]
    }
   ],
   "source": [
    "# Number of entries returned by file_lines(), empty strings included.\n",
    "print(len(lines))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['',\n",
       " '畹华吾侄',\n",
       " '你接到这封信的时候',\n",
       " '不知道大伯还在不在人世了',\n",
       " '',\n",
       " '咱们梅家从你爷爷起',\n",
       " '就一直小心翼翼地唱戏',\n",
       " '侍奉宫廷侍奉百姓',\n",
       " '从来不曾遭此大祸',\n",
       " '太后的万寿节谁敢不穿红']"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at the first ten parsed entries.\n",
    "lines[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Delete any existing database file first, so the table created below starts\n",
    "# empty every time this cell runs.\n",
    "db = 'conversation.db'\n",
    "if os.path.exists(db):\n",
    "    os.remove(db)\n",
    "conn = sqlite3.connect(db)\n",
    "cur = conn.cursor()\n",
    "# Single two-column table: `ask` receives the question, `answer` the reply.\n",
    "cur.execute(\"\"\"\n",
    "    CREATE TABLE IF NOT EXISTS conversation\n",
    "    (ask text, answer text)\n",
    "    \"\"\")\n",
    "conn.commit()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def contain_chinese(s):\n",
    "    \"\"\"Tell whether `s` has at least one character in U+4E00..U+9FA5.\"\"\"\n",
    "    return re.search('[\\u4e00-\\u9fa5]+', s) is not None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def combine(a, b):\n",
    "    \"\"\"Concatenate two utterances, separating them with '，' when needed.\n",
    "\n",
    "    When either string is empty the other one is returned as is ('' when both\n",
    "    are empty). Otherwise '，' is inserted between them unless `a` already\n",
    "    ends with one of the punctuation marks listed below.\n",
    "    \"\"\"\n",
    "    if not a or not b:\n",
    "        return a or b\n",
    "    chinese_punctuation = '。？！，、；：“”‘’（）—'\n",
    "    # ASCII '?' was missing from this set, so a question ending in '?' used to\n",
    "    # get a redundant '，' glued on after the question mark.\n",
    "    english_punctuation = '.,;:!?\\'\"-[](){}…<>/'\n",
    "    if a[-1] in chinese_punctuation or a[-1] in english_punctuation:\n",
    "        return a + b\n",
    "    return a + '，' + b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def valid(a, max_len=0):\n",
    "    \"\"\"Accept `a` when it is non-empty, contains Chinese and fits `max_len`.\n",
    "\n",
    "    A `max_len` of zero or less disables the length limit.\n",
    "    \"\"\"\n",
    "    if len(a) == 0 or not contain_chinese(a):\n",
    "        return False\n",
    "    return max_len <= 0 or len(a) <= max_len"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def insert(a, b):\n",
    "    \"\"\"Add one (ask, answer) row to the `conversation` table via the global `cur`.\n",
    "\n",
    "    The values are bound as SQL parameters rather than pasted into the\n",
    "    statement text, so quotes or other special characters in the corpus cannot\n",
    "    alter the statement. Nothing is committed here; the caller commits.\n",
    "    \"\"\"\n",
    "    cur.execute('INSERT INTO conversation VALUES (?, ?)', (a, b))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Having more data usually doesn't hurt. Training on a bigger corpus should be beneficial. OpenSubtitles seems to be the biggest for now. Another trick to artificially increase the dataset size when creating the corpus could be to split the sentences of each training sample (e.g. from the sample Q:Sentence 1. Sentence 2. => A:Sentence X. Sentence Y. we could generate 3 new samples: Q:Sentence 1. Sentence 2. => A:Sentence X., Q:Sentence 2. => A:Sentence X. Sentence Y. and Q:Sentence 2. => A:Sentence X.. Warning: other combinations like Q:Sentence 1. => A:Sentence X. won't work because it would break the transition 2 => X which links the question to the answer)\n",
    "\n",
    "https://github.com/Conchylicultor/DeepQA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "4268080it [01:50, 38577.75it/s]\n"
     ]
    }
   ],
   "source": [
    "words = Counter()  # per-character frequencies over every line\n",
    "keys = set()       # (question, answer) pairs already inserted\n",
    "inserted = 0\n",
    "COMMIT_EVERY = 50000  # rows written between intermediate commits\n",
    "\n",
    "\n",
    "def insert_if(question, answer, input_len=12, output_len=20):\n",
    "    \"\"\"Insert the pair if both sides pass valid() and it has not been seen yet.\n",
    "\n",
    "    `input_len` / `output_len` are the maximum lengths handed to valid() for\n",
    "    the question and the answer. Returns 1 when a row was written, else 0.\n",
    "    \"\"\"\n",
    "    # A tuple key cannot collide the way a string-concatenated key could.\n",
    "    k = (question, answer)\n",
    "    if valid(question, input_len) and valid(answer, output_len) and k not in keys:\n",
    "        insert(question, answer)\n",
    "        keys.add(k)\n",
    "        return 1\n",
    "    return 0\n",
    "\n",
    "\n",
    "def insert_window(a, b, c, d):\n",
    "    \"\"\"Insert the pairs built around the b -> c transition of four lines.\n",
    "\n",
    "    Returns how many rows were written. Nothing is produced when `b` or `c`\n",
    "    is empty: combine() would otherwise hand back `a` or `d` alone and pair\n",
    "    two lines that an empty entry separates (presumably two different\n",
    "    conversations -- confirm against the corpus format).\n",
    "    \"\"\"\n",
    "    if not b or not c:\n",
    "        return 0\n",
    "    count = insert_if(combine(a, b), c)   # a, b -> c\n",
    "    count += insert_if(b, combine(c, d))  # b -> c, d\n",
    "    count += insert_if(b, c)              # b -> c\n",
    "    return count\n",
    "\n",
    "\n",
    "a = b = c = d = ''\n",
    "pending = 0  # rows written since the last commit\n",
    "for line in tqdm(lines):\n",
    "    words.update(line)\n",
    "    # Slide the four-line window one step forward.\n",
    "    a, b, c, d = b, c, d, line\n",
    "    added = insert_window(a, b, c, d)\n",
    "    inserted += added\n",
    "    pending += added\n",
    "    # One step can add up to three rows, so compare against a threshold\n",
    "    # instead of testing for an exact multiple of the batch size.\n",
    "    if pending >= COMMIT_EVERY:\n",
    "        conn.commit()\n",
    "        pending = 0\n",
    "# Shift once more so the last line also gets its turn as the answer `c`.\n",
    "inserted += insert_window(b, c, d, '')\n",
    "conn.commit()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.1+"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
